import argparse

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import DataSet
from models import model


def test(net, loader, device=None):
    """Evaluate classification accuracy of ``net`` over every batch in ``loader``.

    Switches ``net`` to eval mode (it is left in eval mode on return) and runs
    inference without autograd tracking, so no computation graph is built or
    kept alive across batches.

    Args:
        net: module whose forward output holds per-class scores; the predicted
            class is taken as ``argmax`` over dim 1.
        loader: iterable of dicts with a ``'video'`` tensor (model input) and a
            ``'label'`` tensor (target class indices).
        device: device each batch is moved to. Defaults to the device of the
            first parameter of ``net``; falls back to CUDA when ``net`` has no
            parameters (the original hard-coded behaviour).

    Returns:
        Fraction of samples whose predicted class equals the label.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        try:
            device = next(net.parameters()).device
        except StopIteration:
            device = torch.device("cuda")

    net.eval()
    correct = 0
    total = 0
    # Inference only: skipping autograd avoids retaining activations per batch.
    with torch.no_grad():
        for data in loader:
            # non_blocking lets copies from pinned host memory overlap compute.
            video = data['video'].to(device, non_blocking=True)
            label = data['label'].to(device, non_blocking=True)
            y_hat = net(video)
            total += y_hat.shape[0]
            correct += (y_hat.argmax(1) == label).sum().item()

    if total == 0:
        raise ValueError("test loader yielded no samples")
    acc = correct / total
    print("Test ACC:{}".format(acc))
    return acc


# The name starts with "test"; stop pytest from collecting this as a test case.
test.__test__ = False

def _parse_args():
    """Build the command-line interface and parse the process arguments."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--train_dataset', type=str, default='Train/',
                        help='root path of dataset')
    parser.add_argument('--test_dataset', type=str, default='Test/',
                        help='root path of dataset')
    parser.add_argument('--nb-frames', type=int, default=32, help='frames of each video')
    parser.add_argument('--batch-size', type=int, default=24)
    parser.add_argument('--epoch', type=int, default=40)
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate, defaults to 1e-3')
    parser.add_argument('--nb-workers', type=int, default=16, help='Number of workers for dataloader.')
    parser.add_argument('--nb-class', type=int, default=309, help='Number of class for dataset.')
    return parser.parse_args()


def _train_one_epoch(net, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader`` on CUDA.

    Returns:
        (mean loss per batch, number of correctly classified samples).
        Assumes ``loader`` yields at least one batch.
    """
    net.train()
    running_loss = 0.
    num_batches = 0
    correct = 0
    for data in loader:
        # The loaders use pin_memory=True, which only pays off with async copies.
        video = data['video'].cuda(non_blocking=True)
        label = data['label'].cuda(non_blocking=True)
        y_hat = net(video)
        loss = criterion(y_hat, label)
        running_loss += loss.item()
        num_batches += 1
        correct += y_hat.argmax(1).eq(label).sum().item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return running_loss / num_batches, correct


def main():
    """Train ``model.VideoModel`` and report test accuracy before and after every epoch."""
    args = _parse_args()

    criterion = nn.CrossEntropyLoss()
    net = model.VideoModel(args.nb_class).cuda()
    VDataset = DataSet.VideoDataset(args.train_dataset, args.nb_frames)
    VDataloader = DataLoader(VDataset, batch_size=args.batch_size, num_workers=args.nb_workers,
                             shuffle=True, drop_last=False, pin_memory=True)

    test_VDataset = DataSet.VideoTestDataset(args.test_dataset, args.nb_frames)
    # Accuracy does not depend on sample order; keep evaluation deterministic.
    test_VDataloader = DataLoader(test_VDataset, batch_size=args.batch_size, num_workers=args.nb_workers,
                                  shuffle=False, drop_last=False, pin_memory=True)

    # Fail fast instead of a ZeroDivisionError deep inside the epoch loop.
    if len(VDataset) == 0 or len(test_VDataset) == 0:
        raise SystemExit("empty dataset: train={}, test={}".format(len(VDataset), len(test_VDataset)))

    print("DataSet:{}, DataLoader:{}".format(len(VDataset), len(VDataloader)))
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    # optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    # Baseline accuracy of the untrained network (only when training runs at all).
    if args.epoch > 0:
        test(net, test_VDataloader)

    for epoch in range(args.epoch):
        mean_loss, correct = _train_one_epoch(net, VDataloader, criterion, optimizer)
        print("video step size 10 epoch:{}, running loss:{}, running acc:{}".format(
            epoch, mean_loss, correct / len(VDataset)))
        test_acc = test(net, test_VDataloader)
        scheduler.step()
        # torch.save(net.state_dict(), "ckpt_full2/video_{}_{}.pt".format(epoch, test_acc))


if __name__ == "__main__":
    main()


